{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reverse Image Search featuring Impressionist Paintings\n",
    "\n",
    "It's easy to get your own reverse-image search up and running locally with a vector database. We use [Milvus Lite](https://milvus.io/docs/milvus_lite.md) and [PyTorch](https://pytorch.org/) to build a local reverse image search using the Impressionist-Classifier Dataset found on [Kaggle](https://www.kaggle.com/datasets/delayedkarma/impressionist-classifier-data).\n",
    "\n",
    "We start by importing the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install pymilvus==2.2.5 torch gdown torchvision tqdm milvus"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we download the dataset using `gdown` to download it from a publicly hosted Google Drive and `zipfile` to unzip the images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gdown\n",
    "import zipfile\n",
    "\n",
    "url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'\n",
    "output = './paintings.zip'\n",
    "gdown.download(url, output)\n",
    "\n",
    "with zipfile.ZipFile(\"./paintings.zip\",\"r\") as zip_ref:\n",
    "    zip_ref.extractall(\"./paintings\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we've downloaded all the images, we need to set up some parameters that we'll use later on to work with our vector database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collection Setup Parameters\n",
    "COLLECTION_NAME = 'image_search'  # Collection name\n",
    "DIMENSION = 2048  # Embedding vector size in this example\n",
    "\n",
    "# Inference Arguments\n",
    "BATCH_SIZE = 128\n",
    "TOP_K = 3"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start our local Vector Database instance with Milvus Lite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from milvus import default_server\n",
    "from pymilvus import connections, utility\n",
    "\n",
    "# (OPTIONAL) Set if you want store all related data to specific location\n",
    "# Default location:\n",
    "#   %APPDATA%/milvus-io/milvus-server on windows\n",
    "#   ~/.milvus-io/milvus-server on linux\n",
    "# default_server.set_base_dir('milvus_data')\n",
    "\n",
    "# (OPTIONAL) if you want cleanup previous data\n",
    "# default_server.cleanup()\n",
    "\n",
    "# Start your milvus server\n",
    "default_server.start()\n",
    "\n",
    "# Now you could connect with localhost and the given port\n",
    "# Port is defined by default_server.listen_port\n",
    "connections.connect(host='127.0.0.1', port=default_server.listen_port)\n",
    "\n",
    "# Check if the server is ready.\n",
    "print(utility.get_server_version())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With our server up, we're ready to get started. We start by defining the vector database schema and establishing a collection. Each entry in our collection features three fields. First, an `id` for regular querying, next a filepath to identify where the image is stored locally, and lastly, the embedding that we use for similarity search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymilvus import FieldSchema, CollectionSchema, DataType, Collection\n",
    "\n",
    "# Check that the collection does not yet exist\n",
    "if utility.has_collection(COLLECTION_NAME):\n",
    "    utility.drop_collection(COLLECTION_NAME)\n",
    "\n",
    "fields = [\n",
    "    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
    "    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),\n",
    "    FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)\n",
    "]\n",
    "schema = CollectionSchema(fields=fields)\n",
    "collection = Collection(name=COLLECTION_NAME, schema=schema)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have a connection to our vector database and an established collection, we create an index to search on. For this example, we use an IVF Flat index measured with the L2 Norm and 128 cluster units (`nlist`).\n",
    "\n",
    "Click here to learn more about [vector indexes](https://medium.com/vector-database/how-to-choose-an-index-in-milvus-4f3d15259212)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index_params = {\n",
    "    \"index_type\": \"IVF_FLAT\",\n",
    "    \"metric_type\": \"L2\",\n",
    "    \"params\": {\"nlist\": 128},\n",
    "}\n",
    "collection.create_index(field_name=\"image_embedding\", index_params=index_params)\n",
    "collection.load()"
   ]
  },
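   {
    "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "To build some intuition for what `nlist` and `nprobe` control, here is a minimal pure-Python sketch of the inverted-file idea. It is not how Milvus implements `IVF_FLAT`: the helper names (`l2`, `build_ivf`, `ivf_search`) are hypothetical, and the centroids are hand-picked stand-ins for the cluster centres that a real index learns from the data.\n",
     "\n",
     "```python\n",
     "import math\n",
     "\n",
     "def l2(a, b):\n",
     "    # Euclidean distance between two equal-length vectors\n",
     "    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))\n",
     "\n",
     "def build_ivf(vectors, centroids):\n",
     "    # File each vector id under its nearest centroid (one bucket per cluster)\n",
     "    buckets = {c: [] for c in range(len(centroids))}\n",
     "    for vec_id, vec in enumerate(vectors):\n",
     "        nearest = min(range(len(centroids)), key=lambda c: l2(vec, centroids[c]))\n",
     "        buckets[nearest].append(vec_id)\n",
     "    return buckets\n",
     "\n",
     "def ivf_search(query, vectors, centroids, buckets, nprobe, top_k):\n",
     "    # Only scan the nprobe clusters whose centroids are closest to the query\n",
     "    probe = sorted(range(len(centroids)), key=lambda c: l2(query, centroids[c]))[:nprobe]\n",
     "    candidates = [vec_id for c in probe for vec_id in buckets[c]]\n",
     "    ranked = sorted(candidates, key=lambda vec_id: l2(query, vectors[vec_id]))\n",
     "    return [(vec_id, l2(query, vectors[vec_id])) for vec_id in ranked[:top_k]]\n",
     "\n",
     "vectors = [(0.0, 0.0), (0.1, 0.2), (5.0, 5.0), (5.2, 4.9), (9.0, 0.0)]\n",
     "centroids = [(0.0, 0.0), (5.0, 5.0), (9.0, 0.0)]  # assumed cluster centres (nlist = 3)\n",
     "buckets = build_ivf(vectors, centroids)\n",
     "results = ivf_search((5.1, 5.0), vectors, centroids, buckets, nprobe=1, top_k=2)\n",
     "print(results)\n",
     "```\n",
     "\n",
     "A larger `nprobe` scans more clusters, which makes the search slower but less likely to miss a close neighbour that was filed under an adjacent cluster."
    ]
   },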
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's get the embeddings. We use the [ResNet50 model from PyTorch](https://pytorch.org/vision/master/models/generated/torchvision.models.resnet50.html) to get the embeddings. Normally, the last layer of the ResNet50 model outputs classifications for a dataset. In our case, we need the embeddings, not the classifications, so we remove the last layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)\n",
    "print(len(paths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run this before importing th resnet50 model if you run into an SSL certificate URLError\n",
    "# import ssl\n",
    "# ssl._create_default_https_context = ssl._create_unverified_context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "# Load the embedding model with the last layer removed\n",
    "# the last layer is the classification layer, the embeddings are the layer right before\n",
    "model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)\n",
    "model = torch.nn.Sequential(*(list(model.children())[:-1]))\n",
    "model.eval()"
   ]
  },
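   {
    "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "As a quick shape check, the sketch below builds the same truncated network from a randomly initialised ResNet50 and passes dummy images through it. Random weights are an assumption made here to avoid a download, so the values are meaningless and only the shapes matter. The names `backbone`, `feature_extractor`, and `dummy_batch` are illustrative.\n",
     "\n",
     "```python\n",
     "import torch\n",
     "import torchvision\n",
     "\n",
     "# Randomly initialised ResNet50: nothing is downloaded, only the shapes are of interest\n",
     "backbone = torchvision.models.resnet50()\n",
     "feature_extractor = torch.nn.Sequential(*list(backbone.children())[:-1])\n",
     "feature_extractor.eval()\n",
     "\n",
     "dummy_batch = torch.randn(2, 3, 224, 224)  # two fake RGB images\n",
     "with torch.no_grad():\n",
     "    pooled = feature_extractor(dummy_batch)          # shape (2, 2048, 1, 1)\n",
     "    embeddings = torch.flatten(pooled, start_dim=1)  # shape (2, 2048)\n",
     "    # A batch of one keeps its batch axis when flattened from dim 1\n",
     "    single = torch.flatten(feature_extractor(dummy_batch[:1]), start_dim=1)\n",
     "\n",
     "print(pooled.shape, embeddings.shape, single.shape)\n",
     "```\n",
     "\n",
     "Flattening from dimension 1 keeps one row per image for any batch size, whereas `squeeze()` would also drop the batch axis when the batch holds a single image."
    ]
   },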
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point we have a local Milvus instance running, the data for this example project, and the model we need to get some embeddings. Uploading the data into the vector database as a collection and getting it indexed is the next step. First, we preprocess the data to fit the data that the [ResNet50 Model](https://pytorch.org/vision/master/models/generated/torchvision.models.resnet50.html) was trained on (seen on the bottom of the page). Then we batch the data and \"upload\" it to our vector database.\n",
    "\n",
    "*the second block of this step takes a while to run (about 12 minutes on a 16GB RAM M1 Mac), we are running almost 5000 images through ResNet50, now would be a good time to grab a snack or something to drink :)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])"
   ]
  },
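   {
    "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "To make the last two transforms concrete, the sketch below reproduces their arithmetic for a single 8-bit RGB pixel in plain Python: `ToTensor` rescales each channel from [0, 255] to [0.0, 1.0], and `Normalize` then subtracts the channel mean and divides by the channel standard deviation. The helper `normalize_pixel` is hypothetical and exists only for illustration.\n",
     "\n",
     "```python\n",
     "# Per-channel statistics copied from the transforms.Normalize call above\n",
     "MEAN = [0.485, 0.456, 0.406]\n",
     "STD = [0.229, 0.224, 0.225]\n",
     "\n",
     "def normalize_pixel(rgb):\n",
     "    # Hypothetical helper: mimic ToTensor followed by Normalize for one pixel\n",
     "    scaled = [value / 255 for value in rgb]  # ToTensor: [0, 255] -> [0.0, 1.0]\n",
     "    return [(s - m) / d for s, m, d in zip(scaled, MEAN, STD)]  # Normalize\n",
     "\n",
     "white = normalize_pixel((255, 255, 255))\n",
     "black = normalize_pixel((0, 0, 0))\n",
     "print(white)\n",
     "print(black)\n",
     "```\n",
     "\n",
     "White pixels map to positive values and black pixels to negative ones, so the model sees inputs centred around zero, as it did during training."
    ]
   },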
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed(data):\n",
    "    with torch.no_grad():\n",
    "        output = model(torch.stack(data[0])).squeeze()\n",
    "        collection.insert([data[1], output.tolist()])\n",
    "\n",
    "data_batch = [[], []]\n",
    "\n",
    "for path in tqdm(paths):\n",
    "    im = Image.open(path).convert('RGB')\n",
    "    data_batch[0].append(preprocess(im))\n",
    "    data_batch[1].append(path)\n",
    "    if len(data_batch[0]) % BATCH_SIZE == 0:\n",
    "        embed(data_batch)\n",
    "        data_batch = [[], []]\n",
    "\n",
    "if len(data_batch[0]) != 0:\n",
    "    embed(data_batch)\n",
    "\n",
    "collection.flush()"
   ]
  },
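   {
    "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The batching pattern in the loop above (fill a batch, process it when full, then handle the leftover partial batch) can be isolated into a small generator. This is a sketch of the pattern only; `chunked` is a hypothetical helper, not part of PyTorch or PyMilvus.\n",
     "\n",
     "```python\n",
     "def chunked(items, size):\n",
     "    # Hypothetical helper: yield consecutive lists of at most `size` items\n",
     "    batch = []\n",
     "    for item in items:\n",
     "        batch.append(item)\n",
     "        if len(batch) == size:\n",
     "            yield batch\n",
     "            batch = []\n",
     "    if batch:  # final partial batch, if any\n",
     "        yield batch\n",
     "\n",
     "batches = list(chunked(['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg', 'e.jpg'], 2))\n",
     "print(batches)\n",
     "```\n",
     "\n",
     "The final batch can be smaller than the requested size, down to a single item, so any code that consumes the batches has to handle that case."
    ]
   },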
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With our vector database populated, we are ready to perform a reverse image search. The second block in this section shows how you pick an image or set of images to reverse image search. For this example, we have two search patterns. The first one (commented out) shows how to search the provided test paintings. The second one shows how to reverse image search for a specific painting provided in the training set. In most cases, it's bad practice to use training data while doing any sort of validation, but this particular case confirms that we are returning the same image when performing a reverse image search an on image that exists in our vector database"
   ]
  },
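   {
    "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The sanity check described above can be sketched without a database: with a distance-based ranking, a query that is identical to a stored vector is at distance zero from it and therefore ranks first. The toy vectors, file names, and the helpers `squared_l2` and `top_k` below are made up for illustration.\n",
     "\n",
     "```python\n",
     "def squared_l2(a, b):\n",
     "    # Squared Euclidean distance between two equal-length vectors\n",
     "    return sum((x - y) ** 2 for x, y in zip(a, b))\n",
     "\n",
     "def top_k(query, database, k):\n",
     "    # Brute-force search: score every stored vector, keep the k closest\n",
     "    scored = sorted((squared_l2(query, vec), path) for path, vec in database.items())\n",
     "    return scored[:k]\n",
     "\n",
     "# Toy stand-in for the collection: file path -> embedding\n",
     "database = {'a.jpg': [0.0, 1.0], 'b.jpg': [1.0, 1.0], 'c.jpg': [3.0, 4.0]}\n",
     "\n",
     "# Query with a vector that is already stored\n",
     "matches = top_k(database['b.jpg'], database, k=2)\n",
     "print(matches)\n",
     "```\n",
     "\n",
     "If the top result for a stored image were anything other than the image itself, that would point to a problem in the embedding or insertion steps."
    ]
   },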
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time # for measuring how long the reverse image search takes\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# modify `search_paths` to modify which images you reverse search\n",
    "\n",
    "# the below example shows how to search for the most similar results to two images in the test set\n",
    "# search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True) \n",
    "\n",
    "# the below example shows how to search for a specific image in the training set\n",
    "# we expect to get the same image back as the \"most similar\" image\n",
    "search_paths = glob.glob('./paintings/paintings/training/training/Degas/213255.jpg', recursive=True)\n",
    "len(search_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed(data):\n",
    "    with torch.no_grad():\n",
    "        ret = model(torch.stack(data))\n",
    "        if len(ret) > 1:\n",
    "            return ret.squeeze().tolist()\n",
    "        else:\n",
    "            return torch.flatten(ret, start_dim=1).tolist()\n",
    "data_batch = [[], []]\n",
    "\n",
    "for path in search_paths:\n",
    "    im = Image.open(path).convert('RGB')\n",
    "    data_batch[0].append(preprocess(im))\n",
    "    data_batch[1].append(path)\n",
    "\n",
    "embeds = embed(data_batch[0])\n",
    "start = time.time()\n",
    "res = collection.search(embeds, \n",
    "                        anns_field='image_embedding', \n",
    "                        param={\"metric_type\": \"L2\",\n",
    "                               \"params\": {\"nprobe\": 10}}, \n",
    "                        limit=TOP_K, \n",
    "                        output_fields=['filepath'])\n",
    "finish = time.time()\n",
    "print(finish - start)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can plot the images for a visual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)\n",
    "\n",
    "# print(type(res))\n",
    "\n",
    "for hits_i, hits in enumerate(res):\n",
    "    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))\n",
    "    axarr[hits_i][0].set_axis_off()\n",
    "    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))\n",
    "    for hit_i, hit in enumerate(hits):\n",
    "        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))\n",
    "        axarr[hits_i][hit_i + 1].set_axis_off()\n",
    "        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))\n",
    "\n",
    "# Save the search result in a separate image file alongside your script.\n",
    "plt.savefig('reverse_search_result.png')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also see which images were classified as the most similar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for hits_i, hits in enumerate(res):\n",
    "    for hit_it, hit in enumerate(hits):\n",
    "        print(hit.entity)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we want to clean up our Milvus instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cleanup\n",
    "default_server.stop()"
   ]
   }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hw_milvus",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
